"""Terminate EC2 instances and spot instance requests orphaned by gitlab-runner."""
import argparse
import collections
import contextlib
import datetime
import sys
import typing

from boto3.session import Session
from cki_lib import chatbot
from cki_lib import logger
from cki_lib import misc
import sentry_sdk

LOGGER = logger.get_logger('cki.cki_tools.orphan_hunter_ec2')


class Ec2Instance:
    """Encapsulate an EC2 instance description returned by ``describe_instances``."""

    def __init__(
            self,
            data: typing.Any,
            ec2_client: typing.Any,
            ec2_hunter: 'Ec2Hunter',
    ) -> None:
        """Initialize instance.

        Args:
            data: one entry of a reservation's ``Instances`` list.
            ec2_client: EC2 client used to terminate the instance.
            ec2_hunter: owning hunter, used to report terminations to chat.
        """
        self.data = data
        self.ec2_client = ec2_client
        self.ec2_hunter = ec2_hunter

    def is_untagged(self) -> bool:
        """Return whether the instance carries no tags at all.

        Only the tags are inspected; matching the key pair name is the
        separate ``with_key`` check.
        """
        return not self.data.get('Tags')

    def with_key(self, key_prefix: typing.List[str]) -> bool:
        """Return whether the instance key pair name starts with any of the prefixes.

        An empty prefix list never matches.
        """
        # `or ''` also covers a KeyName that is present but None
        key_name: str = self.data.get('KeyName') or ''
        return any(key_name.startswith(p) for p in key_prefix)

    @property
    def running_time(self) -> typing.Optional[datetime.timedelta]:
        """Return how long the instance has been running, or None if unknown."""
        launch_time: typing.Optional[datetime.datetime] = self.data.get('LaunchTime')
        if not launch_time:
            return None
        return datetime.datetime.now(datetime.timezone.utc) - launch_time

    def is_older(self, age: datetime.timedelta) -> bool:
        """Return whether the instance has been running longer than ``age``.

        Returns False when the launch time is unknown or ``age`` is falsy
        (e.g. a zero timedelta), which effectively disables the check.
        """
        running_time = self.running_time  # evaluate the property (and "now") once
        if running_time is None or not age:
            return False
        return running_time > age

    def terminate(self, reason: str) -> None:
        """Terminate the EC2 instance, or only log the intent outside production.

        Args:
            reason: short label (e.g. 'untagged', 'runaway') attached to the
                logging context and used as prefix of the chat summary key.
        """
        instance_id: str = self.data['InstanceId']
        logging_env = {
            'instance_data': self.data,
            'running_time': self.running_time,
            'terminate_reason': reason,
        }
        with logger.logging_env(logging_env):
            if misc.is_production():
                state = misc.get_nested_key(self.data, 'State/Name', '')
                # The region is derived by dropping the trailing zone letter.
                # NOTE(review): assumes standard AZ names such as 'us-east-1a'.
                region = misc.get_nested_key(self.data, 'Placement/AvailabilityZone', '')[: -1]
                base_url = f'https://{region}.console.aws.amazon.com/ec2/v2/home?region={region}'
                url = f'{base_url}#InstanceDetails:instanceId={instance_id}'
                message = f'Terminating {state} EC2 instance <{url}|{self}>'
                LOGGER.info('%s', message)
                self.ec2_hunter.log_message(f'{reason} instance', message)
                self.ec2_client.terminate_instances(InstanceIds=[instance_id])
            else:
                LOGGER.info('Would terminate EC2 instance %s if in production mode', self)

    def __str__(self) -> str:
        """Return the ``Name`` tag value, falling back to the instance ID."""
        name: typing.Optional[str] = next((t['Value'] for t in self.data.get('Tags') or []
                                           if t['Key'] == 'Name'), None)
        return name or self.data['InstanceId']


class Ec2SpotInstanceRequest:
    """Encapsulate an EC2 spot instance request from ``describe_spot_instance_requests``."""

    def __init__(
            self,
            data: typing.Any,
            ec2_client: typing.Any,
            ec2_hunter: 'Ec2Hunter',
    ) -> None:
        """Initialize instance.

        Args:
            data: one entry of the ``SpotInstanceRequests`` list.
            ec2_client: EC2 client used to cancel the request.
            ec2_hunter: owning hunter, used to report cancellations to chat.
        """
        self.data = data
        self.ec2_client = ec2_client
        self.ec2_hunter = ec2_hunter

    def with_key(self, key_prefix: typing.List[str]) -> bool:
        """Return whether the requested key pair name starts with any of the prefixes.

        An empty prefix list never matches.
        """
        # `or ''` also covers a KeyName that is present but None
        key_name: str = misc.get_nested_key(self.data, 'LaunchSpecification/KeyName', '') or ''
        return any(key_name.startswith(p) for p in key_prefix)

    @property
    def request_age(self) -> typing.Optional[datetime.timedelta]:
        """Return the age of the EC2 spot instance request, or None if unknown."""
        create_time: typing.Optional[datetime.datetime] = self.data.get('CreateTime')
        if not create_time:
            # mirror Ec2Instance.running_time instead of failing on None arithmetic
            return None
        return datetime.datetime.now(datetime.timezone.utc) - create_time

    def is_older(self, age: datetime.timedelta) -> bool:
        """Return whether an EC2 spot instance request is older than ``age``.

        Returns False when the creation time is unknown or ``age`` is falsy
        (e.g. a zero timedelta), which effectively disables the check.
        """
        request_age = self.request_age
        if request_age is None or not age:
            return False
        return request_age > age

    def cancel(self, reason: str) -> None:
        """Cancel the EC2 spot instance request, or only log the intent outside production.

        Args:
            reason: short label (e.g. 'old') attached to the logging context
                and used as prefix of the chat summary key.
        """
        request_id: str = self.data['SpotInstanceRequestId']
        logging_env = {
            'request_data': self.data,
            'request_age': self.request_age,
            'cancel_reason': reason,
        }
        with logger.logging_env(logging_env):
            if misc.is_production():
                state = self.data.get('State', '')
                # The region is derived by dropping the trailing zone letter.
                # NOTE(review): assumes standard AZ names such as 'us-east-1a'.
                region = misc.get_nested_key(
                    self.data, 'LaunchSpecification/Placement/AvailabilityZone', '')[: -1]
                base_url = f'https://{region}.console.aws.amazon.com/ec2/v2/home?region={region}'
                url = f'{base_url}#SpotInstancesDetails:id={request_id}'
                message = f'Canceling {state} EC2 spot instance request <{url}|{self}>'
                LOGGER.info('%s', message)
                self.ec2_hunter.log_message(f'{reason} spot instance request', message)
                self.ec2_client.cancel_spot_instance_requests(SpotInstanceRequestIds=[request_id])
            else:
                LOGGER.info('Would cancel EC2 spot instance request %s if in production mode', self)

    def __str__(self) -> str:
        """Return the ``Name`` tag value, falling back to the request ID."""
        name: typing.Optional[str] = next((t['Value'] for t in self.data.get('Tags') or []
                                           if t['Key'] == 'Name'), None)
        return name or self.data['SpotInstanceRequestId']


class Ec2Hunter:
    """Terminate EC2 instances orphaned by gitlab-runner."""

    def __init__(
        self,
        summary_only: bool = False,
    ) -> None:
        """Initialize instance."""
        self.ec2_client = Session().client('ec2')
        self.summary_only = summary_only
        self.summary: typing.Dict[str, int] = collections.defaultdict(int)

    @staticmethod
    @contextlib.contextmanager
    def summarize(
        summary_only: bool = False,
    ) -> typing.Generator['Ec2Hunter', None, None]:
        """Optionally suppress chat logging, only posting a summary at the end."""
        try:
            ec2_hunter = Ec2Hunter(summary_only)
            yield ec2_hunter
        finally:
            ec2_hunter.send_summary()

    def instances(self) -> typing.List[Ec2Instance]:
        """Return all running/stopped EC2 instances."""
        return [
            Ec2Instance(i, self.ec2_client, self)
            for r in self.ec2_client.describe_instances(
                Filters=[{'Name': 'instance-state-name', 'Values': ['running', 'stopped']}]
            )['Reservations']
            for i in r['Instances']
        ]

    def spot_instance_requests(self) -> typing.List[Ec2SpotInstanceRequest]:
        """Return all open EC2 spot instance requests."""
        return [
            Ec2SpotInstanceRequest(r, self.ec2_client, self)
            for r in self.ec2_client.describe_spot_instance_requests(
                Filters=[{'Name': 'state', 'Values': ['open']}]
            )['SpotInstanceRequests']
        ]

    def log_message(self, resource: str, message: str) -> None:
        """Send or queue a message for the chat bot."""
        if self.summary_only:
            self.summary[resource] += 1
        else:
            self.send_message(message)

    def send_summary(self) -> None:
        """Send a summary message to the chat bot."""
        resources = ', '.join(f'{v} {k}(s)' for k, v in self.summary.items())
        if resources:
            self.send_message(f'Terminated {resources}')
        self.summary.clear()

    @staticmethod
    def send_message(message: str) -> None:
        """Send a message to the chat bot."""
        chatbot.send_message(f'👻 {message}')


def main(argv: typing.Optional[typing.List[str]] = None) -> None:
    """Terminate EC2 instances orphaned by gitlab-runner.

    An instance is terminated if both of the following hold:
    - it is untagged, uses a key matching ``--untagged-key-prefix``, and is
      older than ``--untagged-age``;
    - failing that, it uses a key matching ``--runaway-key-prefix`` and is
      older than ``--runaway-age``.

    Open spot instance requests whose key matches ``--spot-key-prefix`` are
    canceled once they are older than ``--spot-age``.

    Args:
        argv: command line arguments without the program name; ``None`` lets
            argparse fall back to ``sys.argv``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add_argument = parser.add_argument
    add_argument('--untagged-key-prefix', default=[], action='append',
                 help='EC2 key prefix to use for the search for untagged instances')
    add_argument('--untagged-age', type=misc.parse_timedelta, default='1h',
                 help='Only terminate instances older than that')
    add_argument('--runaway-key-prefix', default=[], action='append',
                 help='EC2 key prefix to use for the search for runaway instances')
    add_argument('--runaway-age', type=misc.parse_timedelta, default='1w',
                 help='Only terminate instances older than that')
    add_argument('--spot-key-prefix', default=[], action='append',
                 help='EC2 key prefix to use for the search for spot instance requests')
    add_argument('--spot-age', type=misc.parse_timedelta, default='1h',
                 help='Only cancel spot instance requests older than that')
    add_argument('--summary-only', action='store_true',
                 help='Only post a summary message to chat channel')
    options = parser.parse_args(argv)

    with Ec2Hunter.summarize(options.summary_only) as hunter:
        for candidate in hunter.instances():
            LOGGER.debug('Checking %s', candidate)
            untagged_orphan = (
                candidate.is_untagged()
                and candidate.with_key(options.untagged_key_prefix)
                and candidate.is_older(options.untagged_age)
            )
            if untagged_orphan:
                candidate.terminate('untagged')
                continue
            # Only instances that were not reaped as untagged get the runaway check.
            if (candidate.with_key(options.runaway_key_prefix)
                    and candidate.is_older(options.runaway_age)):
                candidate.terminate('runaway')

        for request in hunter.spot_instance_requests():
            LOGGER.debug('Checking %s', request)
            if not request.with_key(options.spot_key_prefix):
                continue
            if request.is_older(options.spot_age):
                request.cancel('old')


if __name__ == "__main__":
    # Initialize Sentry error reporting (cki_lib helper) before doing any work.
    misc.sentry_init(sentry_sdk)
    main(argv=sys.argv[1:])
